import os
import json
import numpy as np
from PIL import Image
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.io import Dataset, DataLoader
from paddle.vision import models
import paddlenlp as ppnlp
from paddlenlp.transformers import ErnieTokenizer, ErnieModel

# Fix random seeds so results are reproducible across runs.
paddle.seed(42)
np.random.seed(42)

# ---------------------- 1. Dataset definition ----------------------
class MultiModalDataset(Dataset):
    """Multimodal image-text classification dataset.

    Each record pairs an image file with a text snippet and a category
    label; `__getitem__` returns preprocessed tensors ready for batching.
    """
    def __init__(self, data_path, image_dir, tokenizer, max_seq_len=128, mode='train'):
        """
        Args:
            data_path: path to a JSON file holding a list of records with
                'image', 'text' and 'label' keys.
            image_dir: directory containing the image files.
            tokenizer: paddlenlp-style text tokenizer (callable).
            max_seq_len: maximum text sequence length (pad/truncate target).
            mode: dataset split, one of 'train' / 'val' / 'test'.
        """
        super().__init__()
        self.image_dir = image_dir
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.mode = mode

        # Load the full annotation list into memory.
        with open(data_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)

        # Category-name -> id mapping (adjust to match the dataset).
        self.label2id = {
            '科技': 0, '娱乐': 1, '体育': 2, '财经': 3, '教育': 4
        }
        self.id2label = {v: k for k, v in self.label2id.items()}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Fetch one annotation record.
        item = self.data[idx]
        image_path = os.path.join(self.image_dir, item['image'])
        text = item['text']
        label = self.label2id[item['label']]

        # Load and preprocess the image. The context manager closes the
        # underlying file handle; a bare Image.open() lazy-loads and leaks
        # the open file descriptor.
        with Image.open(image_path) as img:
            image = self._preprocess_image(img.convert('RGB'))

        # Tokenize the text, padding/truncating to max_seq_len.
        encoded_inputs = self.tokenizer(
            text=text,
            max_seq_len=self.max_seq_len,
            pad_to_max_seq_len=True,
            return_attention_mask=True,
            return_token_type_ids=True
        )

        # Convert everything to tensors for the DataLoader collate step.
        input_ids = paddle.to_tensor(encoded_inputs['input_ids'], dtype='int64')
        attention_mask = paddle.to_tensor(encoded_inputs['attention_mask'], dtype='int64')
        token_type_ids = paddle.to_tensor(encoded_inputs['token_type_ids'], dtype='int64')
        label = paddle.to_tensor(label, dtype='int64')

        return {
            'image': image,
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'token_type_ids': token_type_ids,
            'label': label
        }

    def _preprocess_image(self, image):
        """Resize to 224x224, normalize with ImageNet stats, return a CHW float32 tensor."""
        image = image.resize((224, 224), Image.BICUBIC)
        image = np.array(image).astype('float32')
        # Scale to [0, 1], then standardize with ImageNet mean/std.
        image = image / 255.0
        image = (image - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])
        # HWC -> CHW channel order expected by the conv backbone.
        image = np.transpose(image, (2, 0, 1))
        return paddle.to_tensor(image, dtype='float32')

# ---------------------- 2. Multimodal classification model ----------------------
class MultiModalClassifier(nn.Layer):
    """Image + text multimodal classifier.

    A ResNet50 backbone encodes the image and an ERNIE encoder the text;
    each branch is projected into a shared 512-d space, the two vectors
    are concatenated, fused by an MLP, and classified.
    """
    def __init__(self, num_classes, text_encoder='ernie-1.0', pretrained=True):
        super().__init__()

        # Image branch: pretrained ResNet50 with its classification head
        # replaced by identity, plus a projection to the shared space.
        backbone = models.resnet50(pretrained=pretrained)
        backbone.fc = nn.Identity()
        self.image_encoder = backbone
        self.image_proj = nn.Linear(2048, 512)

        # Text branch: pretrained ERNIE plus a projection to the same space.
        self.text_encoder = ErnieModel.from_pretrained(text_encoder)
        self.text_proj = nn.Linear(768, 512)

        # Fusion MLP over the concatenated (512 + 512)-d features.
        self.fusion = nn.Sequential(
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.5)
        )

        # Final classification head.
        self.classifier = nn.Linear(256, num_classes)

    def forward(self, image, input_ids, attention_mask, token_type_ids=None):
        # Image features: [batch, 2048] -> [batch, 512].
        img_feat = self.image_proj(self.image_encoder(image))

        # Text features: pooled output (index 1 of the encoder's return),
        # [batch, 768] -> [batch, 512].
        text_outputs = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        txt_feat = self.text_proj(text_outputs[1])

        # Fuse both modalities and classify: [batch, 1024] -> [batch, num_classes].
        fused = self.fusion(paddle.concat([img_feat, txt_feat], axis=1))
        return self.classifier(fused)

# ---------------------- 3. Model training ----------------------
def train_model(model, train_loader, val_loader, optimizer, criterion, epochs, save_dir):
    """Train the multimodal classifier and checkpoint the best model.

    Args:
        model: the classifier to train.
        train_loader / val_loader: DataLoaders yielding dict batches with
            'image', 'input_ids', 'attention_mask', 'token_type_ids', 'label'.
        optimizer: paddle optimizer built over model.parameters().
        criterion: loss function applied to (logits, label).
        epochs: number of passes over train_loader.
        save_dir: directory where 'best_model.pdparams' is written.

    Returns:
        Best validation accuracy observed across all epochs.
    """
    best_acc = 0.0

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        correct = 0
        total = 0

        for batch in train_loader:
            image = batch['image']
            input_ids = batch['input_ids']
            attention_mask = batch['attention_mask']
            token_type_ids = batch['token_type_ids']
            label = batch['label']

            # Forward pass.
            logits = model(image, input_ids, attention_mask, token_type_ids)
            loss = criterion(logits, label)

            # Backward pass and parameter update.
            loss.backward()
            optimizer.step()
            optimizer.clear_grad()

            # float()/int() work for both 0-d and 1-element tensors;
            # .numpy()[0] raises IndexError on the 0-d arrays returned by
            # newer Paddle versions for scalar tensors.
            train_loss += float(loss)
            total += label.shape[0]
            pred = paddle.argmax(logits, axis=1)
            correct += int((pred == label).sum())

        train_acc = correct / total
        print(f'Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss/len(train_loader):.4f}, Train Acc: {train_acc:.4f}')

        # Validate after each epoch.
        val_acc = evaluate_model(model, val_loader)
        print(f'Epoch [{epoch+1}/{epochs}], Val Acc: {val_acc:.4f}')

        # Checkpoint whenever validation accuracy improves.
        if val_acc > best_acc:
            best_acc = val_acc
            paddle.save(model.state_dict(), os.path.join(save_dir, 'best_model.pdparams'))
            print(f'Model saved at acc: {best_acc:.4f}')

    return best_acc

# ---------------------- 4. Model evaluation ----------------------
def evaluate_model(model, data_loader):
    """Compute classification accuracy of `model` over `data_loader`.

    Args:
        model: the classifier (called as model(image, input_ids,
            attention_mask, token_type_ids)).
        data_loader: DataLoader yielding dict batches (see train_model).

    Returns:
        Accuracy in [0, 1]; 0.0 when the loader yields no samples.
    """
    model.eval()
    correct = 0
    total = 0

    with paddle.no_grad():
        for batch in data_loader:
            image = batch['image']
            input_ids = batch['input_ids']
            attention_mask = batch['attention_mask']
            token_type_ids = batch['token_type_ids']
            label = batch['label']

            # Predicted class = argmax over logits.
            logits = model(image, input_ids, attention_mask, token_type_ids)
            pred = paddle.argmax(logits, axis=1)

            total += label.shape[0]
            # int() handles 0-d tensors, where .numpy()[0] would raise
            # IndexError on newer Paddle versions.
            correct += int((pred == label).sum())

    # Guard against an empty loader (avoids ZeroDivisionError).
    return correct / total if total else 0.0

# ---------------------- 5. Main entry point ----------------------
def main():
    """Entry point: build the data pipeline, train the model, report accuracy."""
    # Run configuration.
    config = {
        'train_data_path': 'data/train.json',  # training annotation file
        'val_data_path': 'data/val.json',      # validation annotation file
        'image_dir': 'data/images',            # image directory
        'save_dir': 'checkpoints',             # checkpoint directory
        'num_classes': 5,                      # number of target classes
        'batch_size': 16,                      # batch size
        'epochs': 10,                          # training epochs
        'learning_rate': 1e-4,                 # learning rate
        'max_seq_len': 128                     # maximum text length
    }

    # Make sure the checkpoint directory exists.
    os.makedirs(config['save_dir'], exist_ok=True)

    # Tokenizer shared by both dataset splits.
    tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')

    def make_loader(data_path, mode, shuffle):
        # Build a dataset for one split and wrap it in a DataLoader.
        dataset = MultiModalDataset(
            data_path,
            config['image_dir'],
            tokenizer,
            config['max_seq_len'],
            mode=mode
        )
        return DataLoader(
            dataset,
            batch_size=config['batch_size'],
            shuffle=shuffle,
            num_workers=4
        )

    train_loader = make_loader(config['train_data_path'], 'train', True)
    val_loader = make_loader(config['val_data_path'], 'val', False)

    # Model, loss and optimizer.
    model = MultiModalClassifier(config['num_classes'])
    criterion = nn.CrossEntropyLoss()
    optimizer = paddle.optimizer.AdamW(
        learning_rate=config['learning_rate'],
        parameters=model.parameters()
    )

    # Train, then reload the best checkpoint and report its accuracy.
    train_model(model, train_loader, val_loader, optimizer, criterion, config['epochs'], config['save_dir'])
    model.set_state_dict(paddle.load(os.path.join(config['save_dir'], 'best_model.pdparams')))
    test_acc = evaluate_model(model, val_loader)
    print(f'Final Test Accuracy: {test_acc:.4f}')

if __name__ == '__main__':
    main()